DAG Nammit

The Challenges and Dangers of Causally Interpreting Machine Learning Models

Brandon M. Greenwell

Slides: https://github.com/bgreenwell/dagnammit

About me

Why does explainability matter?

  • Model debugging - Why did my model make this mistake?
  • Feature engineering - How can I improve my model?
  • Detecting fairness issues - Does my model discriminate?
  • Human-AI cooperation - How can I understand and trust the model’s decisions?
  • Regulatory compliance - Does my model satisfy legal requirements?
  • High-risk applications - Healthcare, finance, judicial, …
  • Common sense

Interpretability in a nutshell 🥜

  • Interpretability vs. explainability
  • Global vs. local explainability
  • ◼️Black-box vs. 🪟glass-box models
  • ⚠️ Model-agnostic ⚠️ vs. model-specific techniques
  • Multicollinearity is the nemesis of interpretability!
  • 🧑‍💻 Lots of good software!


Useful resources

So what’s the problem(s) with causally interpreting machine learning models?

Important

Machine learning is often applied to observational or happenstance data!

Correlation doesn’t imply causation 🙄

Some common causal fallacies:

Some causal fallacies in the wild

Customer retention example

  • Initial goal is to train a model to predict whether a customer will renew their software subscription (taken from Lundberg et al. (2021))
  • Eight features were identified for predicting retention (Did.renew=0/1):
    1. Customer discount offered upon renewal (Discount)
    2. Ad spending on this type of customer since last renewal (Ad.spend)
    3. Customer’s monthly usage (Monthly.usage)
    4. Time since last upgrade upon renewal (Last.upgrade)
    5. No. bugs reported by customer since last renewal (Bugs.reported)
    6. No. interactions with customer since last renewal (Interactions)
    7. No. sales calls with customer since last renewal (Sales.calls)
    8. Health of regional economy upon renewal (Economy)
  • 10k total records: 8k for training and 2k for validation

Retention example (cont.)

Output from an additive logistic regression fit:

              Estimate Std. Error z value Pr(>|z|)
(Intercept)     -0.665      0.134  -4.961    0.000
Sales.calls      0.074      0.060   1.238    0.216
Interactions     0.091      0.056   1.612    0.107
Economy          0.597      0.091   6.589    0.000
Last.upgrade    -0.022      0.005  -4.190    0.000
Discount        -5.950      0.311 -19.106    0.000
Monthly.usage    0.351      0.146   2.406    0.016
Ad.spend         0.602      0.062   9.766    0.000
Bugs.reported    0.259      0.035   7.345    0.000

Tip

Checking variance inflation factors (VIFs) is always a good idea, even for black-box models!
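The VIF idea is simple enough to sketch directly: regress each feature on all the others and compute \(1 / (1 - R_j^2)\). A minimal Python sketch (the slides use R, where `car::vif()` is the usual tool; the data here are made up for illustration):

```python
import numpy as np

def vif(X):
    """VIF for each column j of X: 1 / (1 - R_j^2), where R_j^2 comes
    from regressing column j on all the remaining columns."""
    X = np.asarray(X, dtype=float)
    out = []
    for j in range(X.shape[1]):
        target = X[:, j]
        A = np.column_stack([np.ones(len(X)), np.delete(X, j, axis=1)])
        coef, *_ = np.linalg.lstsq(A, target, rcond=None)
        resid = target - A @ coef
        r2 = 1 - resid.var() / target.var()
        out.append(1.0 / (1.0 - r2))
    return np.array(out)

rng = np.random.default_rng(0)
x1 = rng.normal(size=1000)
x2 = rng.normal(size=1000)             # independent of x1
x3 = x1 + 0.1 * rng.normal(size=1000)  # nearly collinear with x1
vifs = vif(np.column_stack([x1, x2, x3]))
# vifs[1] is near 1; vifs[0] and vifs[2] blow up
```

Large VIFs (rules of thumb vary, e.g., > 5 or > 10) flag exactly the multicollinearity that makes PD plots and importance scores hard to trust.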

Retention example (cont.)

Pearson correlation matrix:

Retention example (cont.)

Variable importance scores from an XGBoost fit:

Partial dependence (PD) plots

Interpreting the PD plots

  • Ad.spend and Discount are important to this (fictional) business because they can be directly manipulated 🎛️
  • 🙌🎉🥳 Hurrah! We can improve retention by
    • ⬆️ Increasing ad spend
    • ⬇️ Decreasing discount amount

NOT SO FAST!!!

The true data generator

\[ \begin{aligned} \mathsf{logit}\left(p\right) = 1.26 &\times \mathtt{Product.need} + \\ 0.56 &\times\mathtt{Monthly.usage} + \\ 0.7 &\times \mathtt{Economy} + \\ 0.35 &\times \mathtt{Discount} + \\ 0.35 &\times \left(1 - \mathtt{Bugs.faced} / 20\right) + \\ 0.035 &\times \mathtt{Sales.calls} + \\ 0.105 &\times \mathtt{Interactions} + \\ 0.7 &\times \left(\mathtt{Last.upgrade} / 4 + 0.25\right)^{-1} + \\ 0 &\times \mathtt{Ad.spend} + \\ &-3.15 + \epsilon\ \end{aligned} \]

Partial dependence vs. truth! 😱

PD plot (black) vs. true causal relationship (red)

Even the experts slip up!

Statistical Learning with Big Data (fantastic talk!)

So now what?


Causal interpretation requires a causal model!!


📺 Watch the first talk by Peter Tennant!

Directed acyclic graphs (DAGs)

  • Useful for representing causal relationships and assumptions
    • Directed: One-sided arrows (→) connect (assumed) causes and effects
    • Acyclic: no directed path can form a closed loop
  • Help determine whether the effect(s) of interest can be estimated from available data
  • Based on strong assumptions that are often unverifiable

DAGs in machine learning

Assume we have five features (X1–X5) and a response (Y). Causally interpreting a machine learning model assumes a very particular DAG!

How your algorithm sees it:

flowchart TB
  X1 --> Y
  X2 --> Y
  X3 --> Y
  X4 --> Y
  X5 --> Y

How the universe works:

flowchart TB
  X1 --> X3
  X1 --> Y
  X2 --> X3
  X2 --> Y
  X3 --> X4
  X3 --> Y
  X4 --> Y
  X5 --> Y

Estimation and confounding

  • In causal inference, a common goal is to estimate the average (causal) effect of some “treatment” on an outcome of interest (e.g., the effect of an ad campaign on sales)

  • Estimation typically requires adjusting (and not adjusting) for certain variables

  • A confounder is a variable that affects both the treatment and the outcome

    • Confounders must be identified, measured, and appropriately adjusted for in the analysis
  • Need to be careful with other covariate roles, like colliders, mediators, etc.
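A collider is the classic trap: adjusting for a common *effect* of two variables manufactures an association that isn't there. A quick Python simulation (a toy sketch, not from the talk) where `x` and `y` are truly independent:

```python
import numpy as np

rng = np.random.default_rng(7)
n = 50_000
x = rng.normal(size=n)
y = rng.normal(size=n)             # x and y are independent by construction
c = x + y + rng.normal(size=n)     # collider: caused by both x and y

# Unadjusted slope of y on x: correctly near zero
unadjusted = np.polyfit(x, y, deg=1)[0]

# "Adjusting" for the collider c induces a spurious effect (about -0.5 here)
A = np.column_stack([np.ones(n), x, c])
adjusted = np.linalg.lstsq(A, y, rcond=None)[0][1]
```

Conditioning on `c` opens a non-causal path between `x` and `y`; with these variances the induced coefficient is analytically \(-1/2\). The same mistake is easy to make by throwing every available feature into a model and then reading its effects causally.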

Adjustment sets are key 🔑

Minimal sufficient adjustment set for estimating

  • Total effect of X3 on Y: {X1, X2}
  • Direct effect of X3 on Y: {X1, X2, X4}

flowchart TB
  X1 --> X3
  X1 --> Y
  X2 --> X3
  X2 --> Y
  X3 --> X4
  X3 --> Y
  X4 --> Y
  X5 --> Y
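These adjustment sets can be checked by simulation. A Python sketch of the DAG above with made-up linear effects (so the true total and direct effects of X3 on Y are known exactly):

```python
import numpy as np

rng = np.random.default_rng(2024)
n = 100_000
# Exogenous nodes
x1 = rng.normal(size=n)
x2 = rng.normal(size=n)
x5 = rng.normal(size=n)
# X1, X2 -> X3; X3 -> X4; everything -> Y (coefficients are illustrative)
x3 = 0.8 * x1 + 0.6 * x2 + rng.normal(size=n)
x4 = 2.0 * x3 + rng.normal(size=n)
y = x1 + x2 + 1.0 * x3 + 0.5 * x4 + x5 + rng.normal(size=n)
# Direct effect of X3 on Y: 1.0; total effect: 1.0 + 0.5 * 2.0 = 2.0

def coef_on_x3(*adjust):
    """Slope on x3 from regressing y on x3 plus the adjustment set."""
    A = np.column_stack([np.ones(n), x3, *adjust])
    return np.linalg.lstsq(A, y, rcond=None)[0][1]

total = coef_on_x3(x1, x2)       # adjust for {X1, X2}     -> about 2.0
direct = coef_on_x3(x1, x2, x4)  # adjust for {X1, X2, X4} -> about 1.0
```

Adjusting for {X1, X2} blocks the backdoor paths and recovers the total effect; additionally adjusting for the mediator X4 isolates the direct effect, exactly as the adjustment sets above prescribe.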

Tip

Tools like DAGitty can help automate this!

Copy and paste this code into DAGitty

dag {
bb="0,0,1,1"
X1 [pos="0.462,0.332"]
X2 [pos="0.425,0.238"]
X3 [exposure,pos="0.532,0.277"]
X4 [pos="0.529,0.396"]
X5 [pos="0.363,0.416"]
Y [outcome,pos="0.439,0.464"]
X1 -> X3
X1 -> Y
X2 -> X3
X2 -> Y
X3 -> X4
X3 -> Y
X4 -> Y
X5 -> Y
}

Useful resources

Retention example (cont.)

Assume strong domain expertise has allowed us to generate the following DAG:

flowchart TB
  MU[Monthly.usage] --> AS[Ad.spend]
  MU --> BF[Bugs.faced]
  MU --> DR[Did.renew]
  SC[Sales.calls] --> IN[Interactions]
  SC --> PN[Product.need]
  SC --> DR
  EC[Economy] --> DR
  DI[Discount] --> DR
  LU[Last.upgrade] --> AS
  LU --> DR
  IN --> DR
  PN --> BR[Bugs.reported]
  PN --> MU
  PN --> DI
  PN --> DR
  BF --> BR
  BF --> DR

Causal Interpretations of Black-Box Models


Mathematical background

The partial dependence (PD) of \(Y\) on \(X_S\) is defined as

\[ \begin{aligned} g_S\left(x_S\right) &= E_{X_C}\left[g\left(x_S, X_C\right)\right] \\ &\approx \frac{1}{N}\sum_{i=1}^N g\left(x_S, x_{iC}\right) \end{aligned} \]
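In code, the Monte Carlo approximation is just an average of predictions over the training rows with \(x_S\) pinned at each grid value. A hypothetical Python sketch (the `predict` function and data are stand-ins; the slides' actual computations are in R):

```python
import numpy as np

def partial_dependence(predict, X, col, grid):
    """Estimate the PD of the model on feature `col` by fixing that
    column at each grid value and averaging predictions over X."""
    pd_vals = []
    for v in grid:
        Xv = X.copy()
        Xv[:, col] = v                       # pin x_S at the grid value
        pd_vals.append(predict(Xv).mean())   # average over X_C
    return np.array(pd_vals)

# Toy check: for g(x) = 2*x0 + x1, the PD on x0 is 2*x0 + mean(x1),
# i.e., linear in x0 with slope 2
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))
predict = lambda X: 2 * X[:, 0] + X[:, 1]
pd = partial_dependence(predict, X, col=0, grid=np.array([-1.0, 0.0, 1.0]))
```

Note that the averaging sweeps \(x_S\) across *all* rows, including ones where that combination of feature values is implausible; with correlated features this extrapolation is exactly where PD plots mislead.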

Retention example (cont.)

PD of Did.renew on Ad.spend, adjusted for only Monthly.usage and Last.upgrade: … 🥁

Ummm … maybe a case of estimand vs. estimate? 🤔

Stop permuting features?! 😱

…PaP metrics can vastly over-emphasize correlated features in both variable importance measures and partial dependence plots.

Retention example (cont.)

Double/debiased machine learning

Given a causal model, double ML ⚠️essentially⚠️ involves three steps:

  1. Predict the outcome (\(y\)) from an appropriate adjustment set and get the residuals (\(r_y\))

  2. Predict the treatment (\(x\)) from the same adjustment set and get the residuals (\(r_x\))

  3. Regress \(r_y\) on \(r_x\) to create a model of the average causal effect (i.e., the slope)
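The three steps can be sketched from scratch on a toy problem (a Python sketch with made-up coefficients and, for brevity, simple least-squares learners standing in for the ML models; the retention example itself uses the DoubleML R package):

```python
import numpy as np

rng = np.random.default_rng(1810)
n = 5_000
z = rng.normal(size=n)                       # adjustment set (confounder)
x = 0.8 * z + rng.normal(size=n)             # treatment, driven by z
y = 1.5 * x + 2.0 * z + rng.normal(size=n)   # outcome; true effect of x is 1.5

def fit_predict(u, v):
    """Stand-in for any ML learner: least-squares fit of v on u."""
    beta = np.polyfit(u, v, deg=1)
    return np.polyval(beta, u)

r_y = y - fit_predict(z, y)                  # step 1: residualize the outcome
r_x = x - fit_predict(z, x)                  # step 2: residualize the treatment
effect = np.polyfit(r_x, r_y, deg=1)[0]      # step 3: slope of r_y on r_x
```

Here `effect` lands near the true 1.5 even though the naive regression of `y` on `x` alone would be badly confounded by `z`; in practice the two residualizing models are flexible ML learners (hence "double machine learning") with cross-fitting.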

Double ML for Ad.spend

dml_data <- DoubleML::DoubleMLData$new(
  data = ret.trn,                              # training data
  y_col = "Did.renew",                         # response
  d_cols = "Ad.spend",                         # treatment
  x_cols = c("Last.upgrade", "Monthly.usage")  # adjustment set
)
lrnr <- mlr3::lrn("regr.ranger", num.trees = 500)
set.seed(1810)  # for reproducibility
dml_plr <- DoubleML::DoubleMLPLR$new(
  dml_data, ml_l = lrnr$clone(), ml_m = lrnr$clone()
)
dml_plr$fit()
# Print results
print(dml_plr)
# ------------------ Fit summary       ------------------
#   Estimates and significance testing of the effect of target variables
#          Estimate. Std. Error t value Pr(>|t|)
# Ad.spend  -0.09634    0.25197  -0.382    0.702

# Compute 95% confidence interval
print(dml_plr$confint())
#               2.5 %   97.5 %
# Ad.spend -0.5901917 0.397511

Designed experiments

  • Randomized controlled trials (RCTs) are arguably still the gold standard, but …
    • 😇 There can be ethical concerns
    • 💰 Can be expensive to implement

However…

Tip

Responsible, transparent use of machine learning can help narrow down the hypothesis space!

Ingot cracking example

I’m reminded of an old (but still fantastic) data mining lecture from Richard De Veaux (skip to the 44:30 mark)

  • 20,000 lb. ingots made in a giant mold
  • Roughly 25% of ingots develop cracks
  • Cracked ingots cost $30,000 to recast
  • Roughly 900 observations (ingots) on 149 variables
  • What’s causing them to crack?

Ingot cracking example (cont.)

  • Lots of iterations, but… “Looks like Chrome(!?)”
  • 🕵️ A glass-box model gave clues for generating a hypothesis (i.e., which variable to focus on)
  • Follow-up randomized experiments led to substantial improvement!

Adding constraints (where feasible)

  • Often useful to constrain the functional form of the model in some way

    • Business considerations
    • Domain knowledge
  • Enforcing sparsity (e.g., EBMs with Sparsity)

  • Enforcing monotonicity between features and the predicted output can be done in several ways during training (e.g., linear and tree-based models)
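For tree ensembles this is typically a single training parameter (e.g., `monotone_constraints` in XGBoost and LightGBM). The underlying idea of projecting a fit onto monotone functions can be illustrated with the classic pool-adjacent-violators algorithm (a Python sketch, not from the talk):

```python
import numpy as np

def pava(y, w=None):
    """Pool-adjacent-violators: project y onto nondecreasing sequences
    (the weighted least-squares isotonic regression of y)."""
    y = np.asarray(y, dtype=float)
    w = np.ones_like(y) if w is None else np.asarray(w, dtype=float)
    vals, wts, counts = [], [], []
    for yi, wi in zip(y, w):
        vals.append(yi); wts.append(wi); counts.append(1)
        # Merge adjacent blocks while monotonicity is violated
        while len(vals) > 1 and vals[-2] > vals[-1]:
            wtot = wts[-2] + wts[-1]
            vals[-2] = (wts[-2] * vals[-2] + wts[-1] * vals[-1]) / wtot
            wts[-2] = wtot
            counts[-2] += counts[-1]
            vals.pop(); wts.pop(); counts.pop()
    return np.repeat(vals, counts)

smoothed = pava([1.0, 3.0, 2.0, 4.0])
# → [1.0, 2.5, 2.5, 4.0]: the dip at 2.0 is pooled away
```

The same projection idea is what lets a shape-constrained fit encode domain knowledge like "risk should not decrease with age."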

Pneumonia example

  • Data contains 46 features on 14199 pneumonia patients
    • Patient demographics (e.g., age)
    • 📐 Various measurements (e.g., heart rate)
    • 🔬 Lab test results (e.g., WBC)
    • 🩻 Chest x-ray results (e.g., pleural effusion)
  • Goal is to predict probability of death (0/1) using a GA2M
  • Data from Caruana et al. (2015) and Wang et al. (2022)

Pneumonia example (cont.)

Living past 100 decreases risk? 🫤

Pneumonia example (cont.)

Adding monotonic constraints can be helpful!

Pneumonia example (cont.)

Having asthma lowers a patient’s risk of dying from pneumonia? 🤯

Pneumonia example (cont.)

According to the doctors, asthmatic patients (A) would likely receive better care earlier (T):

flowchart TB
  A --> R
  A --> T
  T --> R

Pneumonia example (cont.)

  • If we use the model as is to make hospital admission decisions, asthmatic patients are likely to miss out on care they need
  • Interpretability and causal knowledge can help identify such dangerous patterns and improve the model:
    • Force monotonicity (e.g., asthmatic > non-asthmatic)
    • Remove the asthma feature
    • Edit the effect out 😱 (e.g., using GAM Changer)

GAM Changer

Causal discovery? 🤔

🔑 Key takeaways

  • Machine learning models are great at identifying and utilizing patterns and associations in the data to make predictions

  • Causal knowledge can be used to improve these models!

  • Some quotes I like from Becoming A Data Head:

“There are clever ways to use observational data to suggest some causal relationships. [They ALL] rely on strong assumptions and clever statistics.”

“Any claims of causality with observational data should be met with skepticism.” [(ANY!!)]

Questions? 🙋

Source: xkcd comic